{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import skimage.transform\n",
    "import matplotlib.pyplot as plt\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "from main_monodepth_pytorch import Model\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check if CUDA is available"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both configurations below hard-code device 'cuda:0', so this should display True before continuing.\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of CUDA devices visible to PyTorch; the configurations below only use 'cuda:0'\n",
    "# (use_multiple_gpu is False in both).\n",
    "torch.cuda.device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release unoccupied cached GPU memory (e.g. left over from an earlier run in this kernel)\n",
    "# before the model is built.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the training run; passed to Model in the next cell.\n",
    "dict_parameters = edict({\n",
    "    # Dataset, checkpoint and output locations\n",
    "    'data_dir': 'data/kitti/train/',\n",
    "    'val_data_dir': 'data/kitti/val/',\n",
    "    'model_path': 'data/models/monodepth_resnet18_001.pth',\n",
    "    'output_directory': 'data/output/',\n",
    "    # Input size and network\n",
    "    'input_height': 256,\n",
    "    'input_width': 512,\n",
    "    'model': 'resnet18_md',\n",
    "    'pretrained': True,\n",
    "    # Run mode and optimisation settings\n",
    "    'mode': 'train',\n",
    "    'epochs': 200,\n",
    "    'learning_rate': 1e-4,\n",
    "    'batch_size': 8,\n",
    "    'adjust_lr': True,\n",
    "    'device': 'cuda:0',\n",
    "    # Data augmentation\n",
    "    'do_augmentation': True,\n",
    "    'augment_parameters': [0.8, 1.2, 0.5, 2.0, 0.8, 1.2],\n",
    "    # Optional printing during training\n",
    "    'print_images': False,\n",
    "    'print_weights': False,\n",
    "    # Data loading and hardware\n",
    "    'input_channels': 3,\n",
    "    'num_workers': 8,\n",
    "    'use_multiple_gpu': False,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model wrapper from the training configuration defined above.\n",
    "model = Model(dict_parameters)\n",
    "# Optional: uncomment to load previously saved weights from this path before training\n",
    "# (presumably the last checkpoint of an earlier run -- confirm against Model.load).\n",
    "#model.load('data/models/monodepth_resnet18_001_last.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Start training with the settings in dict_parameters (mode 'train', 200 epochs).\n",
    "model.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for inference; a separate Model instance is built from them.\n",
    "dict_parameters_test = edict({\n",
    "    # Test data, checkpoint to evaluate and output location\n",
    "    'data_dir': 'data/test',\n",
    "    'model_path': 'data/models/monodepth_resnet18_001_cpt.pth',\n",
    "    'output_directory': 'data/output/',\n",
    "    # Input size and network\n",
    "    'input_height': 256,\n",
    "    'input_width': 512,\n",
    "    'model': 'resnet18_md',\n",
    "    'pretrained': False,\n",
    "    # Run mode and hardware\n",
    "    'mode': 'test',\n",
    "    'device': 'cuda:0',\n",
    "    'input_channels': 3,\n",
    "    'num_workers': 4,\n",
    "    'use_multiple_gpu': False,\n",
    "})\n",
    "model_test = Model(dict_parameters_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference with the test configuration. Presumably this writes the disparity arrays\n",
    "# into output_directory ('data/output/'), which the next cell loads -- confirm against Model.test.\n",
    "model_test.test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the predictions from the configured output directory rather than repeating the path,\n",
    "# so changing output_directory in the test configuration keeps this cell working.\n",
    "# Use 'disparities.npy' for the output without post-processing.\n",
    "disp = np.load(os.path.join(dict_parameters_test.output_directory, 'disparities_pp.npy'))\n",
    "disp.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resize the first disparity map to 375 x 1242 (rows x columns) and display it.\n",
    "# NOTE(review): 375 x 1242 is presumably the original test image size -- confirm.\n",
    "first_disparity = disp[0].squeeze()\n",
    "disp_to_img = skimage.transform.resize(first_disparity, [375, 1242], mode='constant')\n",
    "plt.imshow(disp_to_img, cmap='plasma')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save a color image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name the image after the checkpoint file, without its directory or extension.\n",
    "# os.path handles any path separator and any extension length, unlike split('/')[-1][:-4].\n",
    "checkpoint_name = os.path.splitext(os.path.basename(dict_parameters_test.model_path))[0]\n",
    "plt.imsave(os.path.join(dict_parameters_test.output_directory, checkpoint_name + '_test_output.png'),\n",
    "           disp_to_img, cmap='plasma')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save all test images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use a loop-local name so the single image held in disp_to_img (displayed above and reused\n",
    "# by the grayscale cell below) is not overwritten by the last test image.\n",
    "for i, disparity in enumerate(disp):\n",
    "    resized_disparity = skimage.transform.resize(disparity.squeeze(), [375, 1242], mode='constant')\n",
    "    plt.imsave(os.path.join(dict_parameters_test.output_directory, 'pred_' + str(i) + '.png'),\n",
    "               resized_disparity, cmap='plasma')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save a grayscale image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name the image after the checkpoint file, without its directory or extension.\n",
    "# os.path handles any path separator and any extension length, unlike split('/')[-1][:-4].\n",
    "checkpoint_name = os.path.splitext(os.path.basename(dict_parameters_test.model_path))[0]\n",
    "plt.imsave(os.path.join(dict_parameters_test.output_directory, checkpoint_name + '_gray.png'),\n",
    "           disp_to_img, cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
